Conversation

@Mikep86 Mikep86 commented Aug 4, 2025

This is a POC implementation of cross-cluster search (CCS) support for the semantic query when ccs_minimize_roundtrips=true.

It implements:

  • semantic query multi-index handling (adapted from Support using the semantic query across multiple inference IDs #120755)
  • semantic query CCS support when ccs_minimize_roundtrips=true
  • Detection of when ccs_minimize_roundtrips=false
  • Integration tests demonstrating the high-level functionality
  • A way to reuse local embeddings on remote clusters when compatible
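The "detection" item above can be sketched in isolation. The class and method names below (CcsSupport, RewriteContext, checkCcsSupported) are hypothetical stand-ins, not the PR's actual API; they only illustrate rejecting the unsupported mode up front:

```java
// Minimal standalone sketch of detecting ccs_minimize_roundtrips=false.
// All names here (CcsSupport, RewriteContext) are hypothetical illustrations,
// not the PR's actual classes.
public class CcsSupport {

    /** Hypothetical stand-in for the flag a query rewrite context would carry. */
    public record RewriteContext(boolean ccsMinimizeRoundtrips, boolean crossClusterSearch) {}

    /**
     * Reject the unsupported combination up front so the user gets a clear
     * error instead of a confusing partial result.
     */
    public static void checkCcsSupported(RewriteContext context) {
        if (context.crossClusterSearch() && context.ccsMinimizeRoundtrips() == false) {
            throw new IllegalArgumentException(
                "semantic query does not support cross-cluster search when ccs_minimize_roundtrips=false"
            );
        }
    }

    public static void main(String[] args) {
        checkCcsSupported(new RewriteContext(true, true)); // supported: no exception
        try {
            checkCcsSupported(new RewriteContext(false, true));
            throw new AssertionError("expected rejection");
        } catch (IllegalArgumentException expected) {
            System.out.println("rejected: " + expected.getMessage());
        }
    }
}
```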

Comment on lines +109 to +110
SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(INFERENCE_FIELD, "foo");
queryBuilder.setModelRegistrySupplier(() -> modelRegistry);
Contributor Author

This is admittedly a hacky way to pass the model registry to the semantic query, but I was looking for a way to do it that didn't involve a lot of refactoring. The proper way to do this is likely through the constructor.

Comment on lines -226 to 322

String inferenceId = getInferenceIdForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
SetOnce<InferenceServiceResults> inferenceResultsSupplier = new SetOnce<>();
boolean noInferenceResults = false;
if (inferenceId != null) {
    InferenceAction.Request inferenceRequest = new InferenceAction.Request(
        TaskType.ANY,
        inferenceId,
        null,
        null,
        null,
        List.of(query),
        Map.of(),
        InputType.INTERNAL_SEARCH,
        null,
        false
    );

    queryRewriteContext.registerAsyncAction(
        (client, listener) -> executeAsyncWithOrigin(
            client,
            ML_ORIGIN,
            InferenceAction.INSTANCE,
            inferenceRequest,
            listener.delegateFailureAndWrap((l, inferenceResponse) -> {
                inferenceResultsSupplier.set(inferenceResponse.getResults());
                l.onResponse(null);
            })
        )
    );
MapEmbeddingsProvider currentEmbeddingsProvider;
if (embeddingsProvider != null) {
    if (embeddingsProvider instanceof MapEmbeddingsProvider mapEmbeddingsProvider) {
        currentEmbeddingsProvider = mapEmbeddingsProvider;
    } else {
        throw new IllegalStateException("Current embeddings provider should be a MapEmbeddingsProvider");
    }
} else {
    // The inference ID can be null if either the field name or index name(s) are invalid (or both).
    // If this happens, we set the "no inference results" flag to true so the rewrite process can continue.
    // Invalid index names will be handled in the transport layer, when the query is sent to the shard.
    // Invalid field names will be handled when the query is re-written on the shard, where we have access to the index mappings.
    noInferenceResults = true;
    currentEmbeddingsProvider = new MapEmbeddingsProvider();
}

return new SemanticQueryBuilder(this, noInferenceResults ? null : inferenceResultsSupplier, null, noInferenceResults);
boolean modified = false;
if (queryRewriteContext.hasAsyncActions() == false) {
    ModelRegistry modelRegistry = modelRegistrySupplier.get();
    if (modelRegistry == null) {
        throw new IllegalStateException("Model registry has not been set");
    }

    Set<String> inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
    for (String inferenceId : inferenceIds) {
        MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId);
        InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings);

        if (currentEmbeddingsProvider.getEmbeddings(inferenceEndpointKey) == null) {
            InferenceAction.Request inferenceRequest = new InferenceAction.Request(
                TaskType.ANY,
                inferenceId,
                null,
                null,
                null,
                List.of(query),
                Map.of(),
                InputType.INTERNAL_SEARCH,
                null,
                false
            );

            queryRewriteContext.registerAsyncAction(
                (client, listener) -> executeAsyncWithOrigin(
                    client,
                    ML_ORIGIN,
                    InferenceAction.INSTANCE,
                    inferenceRequest,
                    listener.delegateFailureAndWrap((l, inferenceResponse) -> {
                        currentEmbeddingsProvider.addEmbeddings(
                            inferenceEndpointKey,
                            validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId)
                        );
                        l.onResponse(null);
                    })
                )
            );

            modified = true;
        }
    }
}

return modified ? new SemanticQueryBuilder(this, currentEmbeddingsProvider, false) : this;
}
Contributor Author

This logic demonstrates a way to reuse embeddings cross-cluster, when they are compatible. For the sake of this POC I chose to use the combination of inference ID + minimal service settings to qualify inference endpoints as equal.
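The reuse idea above can be shown as a standalone sketch: embeddings are cached per (inference ID, minimal service settings) pair, so endpoints that look identical across clusters share one set of query embeddings. EndpointKey and ServiceSettings below are hypothetical simplifications of the PR's InferenceEndpointKey and MinimalServiceSettings:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Standalone sketch of the embedding-reuse idea. The types EndpointKey and
// ServiceSettings are hypothetical simplifications, not the PR's classes.
public class EmbeddingReuseSketch {

    /** Simplified stand-in for MinimalServiceSettings. */
    record ServiceSettings(String taskType, int dimensions, String similarity) {}

    /** Records get value-based equals()/hashCode() for free, which makes them usable as map keys. */
    record EndpointKey(String inferenceId, ServiceSettings settings) {}

    private final Map<EndpointKey, List<Float>> embeddingsByEndpoint = new HashMap<>();

    /** Compute embeddings once per compatible endpoint; reuse them on later lookups. */
    List<Float> getOrCompute(EndpointKey key, List<Float> freshlyComputed) {
        return embeddingsByEndpoint.computeIfAbsent(key, k -> freshlyComputed);
    }

    public static void main(String[] args) {
        EmbeddingReuseSketch cache = new EmbeddingReuseSketch();
        ServiceSettings settings = new ServiceSettings("text_embedding", 384, "cosine");

        // Same inference ID + same settings on a remote cluster -> same key -> reuse.
        List<Float> local = cache.getOrCompute(new EndpointKey("my-elser", settings), List.of(0.1f, 0.2f));
        List<Float> remote = cache.getOrCompute(new EndpointKey("my-elser", settings), List.of(9.9f, 9.9f));
        System.out.println("reused: " + (local == remote)); // reused: true
    }
}
```

The design choice under discussion is exactly what goes into the key: adding the cluster name (as suggested below) would scope entries per cluster and disable cross-cluster reuse.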

Member

I wonder if we could make this even simpler. For example, apply a warning if just the inference IDs are different. It is just a warning, after all, that there is some potential detected difference here. This may not be perfect compared to consulting the model registry to detect different models. We could also consider a flag for a lenient mode to suppress warnings if people intentionally want to use different inference IDs.

Contributor Author

Understood, there could be a gap here in detecting compatible embeddings. If we want to be more conservative, we could use something like cluster name + inference ID to identify embeddings in the map. That would mean no embedding reuse cross-cluster, though.

Setting a warning doesn't work for CCS, though, as warning headers are not transmitted back to the primary cluster.

@Mikep86 Mikep86 requested review from jimczi and kderusso August 6, 2025 18:33
Member

@kderusso kderusso left a comment

Nice POC! I've left some high-level comments on functionality since this is still a POC.

Rewriteable.rewriteAndFetch(
original,
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, null),
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, null, original.isCcsMinimizeRoundtrips()),
Member

Is there precedent for having CCS-specific knobs like this in generic search code?

Contributor Author

Not sure, but there's a good case to be made for why this is necessary. The CCS mode affects the query rewrite cycle, so we need a way to know about it within that context. Info is passed to the query rewrite via QueryRewriteContext, hence this implementation.

import java.io.IOException;
import java.util.Objects;

public class InferenceEndpointKey implements Writeable {
Member

I like this conceptually, but to be nitpicky, I would like to find a better name for it.

Contributor Author

Sure, I wasn't focusing too much on names, just proving out the functionality.


ModelRegistry modelRegistry = modelRegistrySupplier.get();
if (modelRegistry == null) {
    throw new IllegalStateException("Model registry has not been set");
Member

We may need to test this to make sure it should actually always return a 500 / trigger a serverless alert, similar to some other alerts we've been seeing for semantic queries.

Contributor Author

This is one of those "should never happen in production" errors. If it does, it's a symptom of an upstream problem that we should be alerted to.

);
MapEmbeddingsProvider currentEmbeddingsProvider;
if (embeddingsProvider != null) {
    if (embeddingsProvider instanceof MapEmbeddingsProvider mapEmbeddingsProvider) {
Member

Not sure if this check should be necessary

Contributor Author

This is another one of those "this should never fail in production" cases. If we get here, it means we're performing coordinator node rewrite. That means that we're performing inference, either on a local or remote cluster.

If we're on a local cluster, we can assume that this node built the initial query and thus the embeddings provider should be a MapEmbeddingsProvider.

If we're on a remote cluster, we can assume that the primary (i.e. local) cluster allows semantic queries to perform CCS, which is directly correlated to usage of MapEmbeddingsProvider.

Either way, we need the representation to be MapEmbeddingsProvider so that we can call addEmbeddings later, hence this check.
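The defensive cast being discussed can be shown in isolation. The types below are minimal stand-ins for the PR's EmbeddingsProvider hierarchy, not the actual classes; the point is the pattern-matching instanceof with a loud failure on the impossible branch:

```java
// Minimal stand-in hierarchy, only to illustrate the check under discussion.
public class ProviderCheckSketch {

    interface EmbeddingsProvider {}
    static class MapEmbeddingsProvider implements EmbeddingsProvider {}
    static class OtherProvider implements EmbeddingsProvider {}

    // On the coordinator-rewrite path the provider must support adding embeddings,
    // so anything other than MapEmbeddingsProvider indicates an upstream bug.
    static MapEmbeddingsProvider requireMapProvider(EmbeddingsProvider provider) {
        if (provider instanceof MapEmbeddingsProvider mapProvider) {
            return mapProvider; // pattern variable binds the narrowed type
        }
        throw new IllegalStateException("Current embeddings provider should be a MapEmbeddingsProvider");
    }

    public static void main(String[] args) {
        System.out.println(requireMapProvider(new MapEmbeddingsProvider()).getClass().getSimpleName());
    }
}
```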

)
);
} else if (inferenceResultsList.size() > 1) {
// The inference call should truncate if the query is too large.
Member

Not sure if that's the case for all models? For example we warn in our docs that OpenAI will error if BYO chunks are too large.

Contributor Author

Remember that we are handling query-time inference here, which will never chunk. We should always get back one inference result. If we get back more, something has gone horribly wrong in the Inference API that we want to know about, hence this check.

This comment may not be fully technically correct, in that some providers may error instead of truncating on huge input. However, in that case we will still get back only one inference result; it will just be an instance of ErrorInferenceResults.
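The invariant described above (query-time inference never chunks, so exactly one result is expected) can be sketched standalone; the method name requireSingleResult is a hypothetical illustration, not the PR's actual helper:

```java
import java.util.List;

// Sketch of the invariant under discussion: zero or many results indicates an
// upstream Inference API bug that should be surfaced loudly rather than
// silently handled. The name requireSingleResult is hypothetical.
public class SingleResultCheck {

    static <T> T requireSingleResult(List<T> results, String inferenceId) {
        if (results.isEmpty()) {
            throw new IllegalStateException("No inference results returned for [" + inferenceId + "]");
        } else if (results.size() > 1) {
            throw new IllegalStateException(
                "Expected exactly one inference result for [" + inferenceId + "], got " + results.size()
            );
        }
        return results.get(0);
    }

    public static void main(String[] args) {
        System.out.println(requireSingleResult(List.of("embedding"), "my-endpoint"));
        try {
            requireSingleResult(List.of("a", "b"), "my-endpoint");
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Note that a provider that errors on oversized input still satisfies this check: the single returned element would simply be an error result rather than an embedding.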

@Mikep86 Mikep86 commented Aug 25, 2025

Superseded by #133466

@Mikep86 Mikep86 closed this Aug 25, 2025